package concurrency;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author Mr.Sun
 * @date 2022年09月04日 15:27
 *
 * 临界区 测试
 * <p>
 *     演示了如何把一个非线程安全的类，在其他类的保护和控制之下，应用于多线程的环境
 * </p>
 */
public class CriticalSelection {
    /**
     * Entry point: exercises the whole-method-synchronized manager ({@link PairManager1}) and
     * the critical-section manager ({@link PairManager2}) side by side.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        PairManager pman1 = new PairManager1(),
                pman2 = new PairManager2();

        testApproaches(pman1, pman2);
    }

    /**
     * Starts one manipulator task and one checker task against each manager, lets them run
     * for about 200 ms, prints the state of both managers, and then terminates the JVM.
     *
     * <p>The started tasks loop indefinitely unless interrupted, so the executor is stopped
     * with {@link ExecutorService#shutdownNow()} (which interrupts running tasks) and the
     * method ends with {@code System.exit(0)} as a backstop for any task that ignores
     * interruption.
     *
     * @param pman1 first manager to exercise
     * @param pman2 second manager to exercise
     */
    static void testApproaches(PairManager pman1, PairManager pman2) {
        ExecutorService exec = Executors.newCachedThreadPool();
        PairManipulator
                pm1 = new PairManipulator(pman1),
                pm2 = new PairManipulator(pman2);

        PairChecker
                pcheck1 = new PairChecker(pman1),
                pcheck2 = new PairChecker(pman2);

        exec.execute(pm1);
        exec.execute(pm2);

        exec.execute(pcheck1);
        exec.execute(pcheck2);

        try {
            TimeUnit.MILLISECONDS.sleep(200);
        } catch (InterruptedException e) {
            System.out.println("Sleep interrupted");
            // Restore the interrupt status instead of silently discarding it.
            Thread.currentThread().interrupt();
        } finally {
            // shutdown() only rejects new submissions; the running tasks never finish on
            // their own, so request cancellation by interrupting them.
            exec.shutdownNow();
        }
        System.out.println("pm1: " + pm1 + "\npm2: " + pm2);
        // Backstop: a task that does not honor interruption would otherwise keep the JVM alive.
        System.exit(0);
    }
}

class PairChecker implements Runnable {
    /** Manager whose pair snapshots are validated; fixed for the lifetime of the task. */
    private final PairManager pm;

    /**
     * Creates a checker task for the given manager.
     *
     * @param pm manager whose pair is repeatedly checked
     */
    public PairChecker(PairManager pm) {
        this.pm = pm;
    }

    /**
     * Repeatedly bumps the manager's check counter, takes a snapshot of its pair, and
     * verifies that the snapshot's two values are equal.
     *
     * <p>The loop stops once the executing thread's interrupt status is observed set, so
     * the task can be cancelled through {@code ExecutorService.shutdownNow()}. A failed
     * check throws {@link Pair.PairValueNotEqualException} out of this method, which ends
     * the task.
     */
    @Override
    public void run() {
        while (!Thread.currentThread().isInterrupted()) {
            pm.checkCounter.incrementAndGet();
            pm.getPair().checkState();
        }
    }
}

class PairManipulator implements Runnable {
    /** Manager whose pair is advanced; fixed for the lifetime of the task. */
    private final PairManager pm;

    /**
     * Creates a manipulator task for the given manager.
     *
     * @param pm manager whose pair is repeatedly incremented
     */
    public PairManipulator(PairManager pm) {
        this.pm = pm;
    }

    /**
     * Repeatedly calls {@link PairManager#increment()} on the manager.
     *
     * <p>The loop stops once the executing thread's interrupt status is observed set between
     * increments, so the task can be cancelled through {@code ExecutorService.shutdownNow()}
     * (provided the manager does not clear the interrupt status).
     */
    @Override
    public void run() {
        while (!Thread.currentThread().isInterrupted()) {
            pm.increment();
        }
    }

    /**
     * Describes the manager's current pair snapshot and how many checks have run against it.
     */
    @Override
    public String toString() {
        return "Pair: " + pm.getPair() + " checkCounter = " + pm.checkCounter.get();
    }
}

// 线程安全的Pair
abstract class PairManager {
    /** Number of checks performed against this manager (incremented by checker tasks). */
    AtomicInteger checkCounter = new AtomicInteger(0);

    /**
     * The shared, non-thread-safe pair. Subclasses must mutate it only while holding this
     * manager's monitor, which {@link #getPair()} also acquires.
     */
    protected Pair p = new Pair();

    /** Snapshots recorded by {@link #store(Pair)}; the list itself is synchronized. */
    private final List<Pair> storage = Collections.synchronizedList(new ArrayList<>());

    /**
     * Returns a copy of the current pair, taken under this manager's monitor so that both
     * values come from the same consistent state. The copy is safe to use without locking.
     *
     * @return a fresh {@link Pair} with the current x and y values
     */
    public synchronized Pair getPair() {
        // Copy so callers never touch the guarded original.
        return new Pair(p.getX(), p.getY());
    }

    /**
     * Records a snapshot and then sleeps for 50 ms to simulate a slow operation.
     *
     * <p>If the sleep is interrupted, the method returns early and restores the thread's
     * interrupt status so callers can observe the cancellation request.
     *
     * @param p snapshot to record
     */
    protected void store(Pair p) {
        storage.add(p);
        try {
            TimeUnit.MILLISECONDS.sleep(50);
        } catch (InterruptedException e) {
            // Do not swallow the interrupt: re-assert it for the enclosing task loop.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Advances both values of the managed pair and stores a snapshot; each subclass
     * chooses how much of that work is done under the manager's lock.
     */
    public abstract void increment();
}

// 同步整个方法
class PairManager1 extends PairManager {

    /**
     * Advances both values of the pair and stores a snapshot while holding this manager's
     * monitor for the entire call, including the slow {@code store} step, so other threads
     * calling {@code getPair()} wait for the whole operation.
     */
    @Override
    public synchronized void increment() {
        this.p.incrementX();
        this.p.incrementY();
        final Pair snapshot = getPair();
        store(snapshot);
    }
}

// 同步临界区代码块
class PairManager2 extends PairManager {

    /**
     * Advances both values of the pair and takes a snapshot inside a critical section guarded
     * by this manager's monitor, then stores the snapshot after the lock is released, so the
     * slow {@code store} step does not block other threads calling {@code getPair()}.
     */
    @Override
    public void increment() {
        final Pair snapshot;
        synchronized (this) {
            this.p.incrementX();
            this.p.incrementY();
            snapshot = getPair();
        }
        store(snapshot);
    }
}

// 非线程安全
class Pair {
    // Not thread-safe: the increments below are unsynchronized read-modify-write operations.
    private int x;
    private int y;

    /**
     * Creates a pair holding the given values.
     *
     * @param x initial x value
     * @param y initial y value
     */
    public Pair(int x, int y) {
        this.x = x;
        this.y = y;
    }

    /** Creates a pair with both values set to zero. */
    public Pair() {
        this.x = 0;
        this.y = 0;
    }

    /** @return the current x value */
    public int getX() {
        return x;
    }

    /** @return the current y value */
    public int getY() {
        return y;
    }

    /** Adds one to x; not atomic, so concurrent callers need external locking. */
    public void incrementX() {
        x += 1;
    }

    /** Adds one to y; not atomic, so concurrent callers need external locking. */
    public void incrementY() {
        y += 1;
    }

    /**
     * Verifies the invariant that both values are equal.
     *
     * @throws PairValueNotEqualException if x and y differ
     */
    public void checkState() {
        if (x == y) {
            return;
        }
        throw new PairValueNotEqualException();
    }

    @Override
    public String toString() {
        return new StringBuilder("x: ").append(x).append(", y: ").append(y).toString();
    }

    /** Signals that a pair's two values differ; the message embeds the offending pair. */
    public class PairValueNotEqualException extends RuntimeException {
        public PairValueNotEqualException() {
            super("一对不相等的值：" + Pair.this);
        }
    }
}
